﻿#pragma kernel SampleSDF
#pragma kernel CalculateSurface
#pragma kernel Triangulate
#pragma kernel Smoothing

#pragma kernel FixArgsBuffer
#pragma kernel FillIndirectDrawBuffer

#include "MathUtils.cginc"
#include "Bounds.cginc"
#include "GridUtils.cginc"
#include "FluidChunkDefs.cginc"
#include "FluidKernels.cginc"
#include "SolverParameters.cginc"
#include "NormalCompression.cginc"

/*  
 *  y         z
 *  ^        /     
 *  |
 *    6----7
 *   /|   /|
 *  2----3 |
 *  | 4--|-5
 *  |/   |/
 *  0----1   --> x
 * 
 */

 /*
 *  static Vector2Int[] cubeEdges =
 *  {
 *      new Vector2Int(7,3),  0
 *      new Vector2Int(7,5),  1
 *      new Vector2Int(7,6),  2
 *
 *      new Vector2Int(6,4),  3 // x 
 *      new Vector2Int(4,5),  4
 *      new Vector2Int(5,1),  5     
 *
 *      new Vector2Int(1,3),  6 // y   
 *      new Vector2Int(3,2),  7    
 *      new Vector2Int(2,6),  8    
 *
 *      new Vector2Int(4,0),  9 // z   
 *      new Vector2Int(2,0),  10  
 *      new Vector2Int(1,0)   11   
 *  };
 */
 
// For each adjacency slot j tested at the end of the Triangulate kernel, the four cube
// edges whose bits are checked in the voxel's edge mask. With the edge numbering of the
// cubeEdges comment above, each row lists the four edges lying on one cube face.
// Row order follows the final layout of Triangulate's `adjacent` array.
// NOTE(review): assumes the bits of `edgeTable` use the same edge numbering as the
// cubeEdges comment - confirm against the CPU-side tables.
static const int4 faceNeighborEdges[] =
{
    int4(0,1,5,6),    // face x = 1 (corners 1,3,5,7)
    int4(0,2,7,8),    // face y = 1 (corners 2,3,6,7)
    int4(4,5,9,11),   // face y = 0 (corners 0,1,4,5)
    int4(3,8,9,10),   // face x = 0 (corners 0,2,4,6)
    int4(6,7,10,11),  // face z = 0 (corners 0,1,2,3)
    int4(1,2,3,4)     // face z = 1 (corners 4,5,6,7)
};

// Offsets of the eight voxel corners relative to corner 0, numbered as in the cube
// diagram above: bit 0 of the index selects +x, bit 1 selects +y, bit 2 selects +z.
// The kernels add these offsets to integer voxel coordinates to address neighboring voxels.
static const float3 corners[] =
{
    float3(0, 0, 0), // 0
    float3(1, 0, 0), // 1
    float3(0, 1, 0), // 2
    float3(1, 1, 0), // 3
    float3(0, 0, 1), // 4
    float3(1, 0, 1), // 5
    float3(0, 1, 1), // 6
    float3(1, 1, 1)  // 7
};

// For each of the (up to) three quads a voxel can emit in Triangulate, the indices into
// `corners` of the three other voxels whose vertices complete the quad.
// Triangulate subtracts 1 from each entry to index its local `adjacent` array, which is
// filled from corners[1..6]. Only the first row is used in 2D mode (one quad per voxel).
static const int3 quadNeighborIndices[] =
{
    int3(1, 3, 2), // x   
    int3(4, 5, 1), // y   
    int3(2, 6, 4), // z  
};

// Lookup table with seven entries; entry k holds 7 - k.
// NOTE(review): not referenced by any kernel visible in this file. The values match the
// index of the diagonally opposite *corner* in the cube diagram (k <-> 7 - k) rather than
// a face numbering, and there is no entry for index 7 - confirm intent before using it.
static const int oppositeFaces[] =
{
    7, //0
    6, //1
    5, //2
    4, //3
    3, //4
    2, //5
    1 //6
};

// The two orders in which Triangulate visits the three components of a
// quadNeighborIndices row. Entry 0 keeps the stored order, entry 1 reverses it, which
// flips the winding of the emitted triangles. Triangulate picks entry 0 when bit 7 of
// the voxel's corner mask is set and entry 1 otherwise.
static const int3 quadWindingOrder[] = {
    int3(0, 1 ,2),
    int3(2, 1 ,0)
};

// Argument record written to indirectBuffer[0] by the FillIndirectDrawBuffer kernel.
// NOTE(review): field order presumably mirrors the indexed indirect draw argument layout
// expected by the graphics API - do not reorder without confirming.
struct indirectDrawIndexedArgs
{
    uint indexCountPerInstance; // set to dispatchBuffer[3] * 6 (six indices per quad).
    uint instanceCount;         // copied from the instanceCount uniform.
    uint startIndex;            // always written as 0.
    uint baseVertexIndex;       // always written as 0.
    uint startInstance;         // always written as 0.
};

// particle grid data:
StructuredBuffer<aabb> solverBounds;   // element 0's min_ corner is used as the origin of the particle grid.
StructuredBuffer<uint> cellOffsets;    // start of each cell in the sorted item array.
StructuredBuffer<uint> cellCounts;     // number of items in each cell.
StructuredBuffer<int> gridHashToSortedIndex; // maps GridHash(cell) to an index into cellOffsets / cellCounts.

// per-particle data, in the order addressed through cellOffsets / cellCounts:
StructuredBuffer<float4> sortedPositions;      // xyz: position, w: volume (1/normalized density).
StructuredBuffer<float4> sortedVelocities;     // w: length(angularVel), vorticity intensity.
StructuredBuffer<float4> sortedPrincipalRadii; // xyz: anisotropic kernel radii, x is used as the kernel support radius.
StructuredBuffer<float4> sortedColors;         // only xyz is read by SampleSDF.
StructuredBuffer<quaternion> sortedOrientations; // rotation of each particle's anisotropic kernel.

// voxel data:
StructuredBuffer<int3> chunkCoords; // for each chunk, spatial coordinates.
StructuredBuffer<keyvalue> hashtable; // size: maxChunks entries.
RWStructuredBuffer<uint> voxelToVertex; // for each voxel, index into the vertices array in case the voxel spawns a vertex, INVALID otherwise.
RWStructuredBuffer<uint> vertexAdjacency; // six slots per vertex: index of the adjacent vertex across each face, or INVALID.
RWStructuredBuffer<float4> voxelVelocities; // for each voxel, fluid velocity value. We are reusing the same ComputeBuffer used for vertexAdjacency.

// edge LUTs
StructuredBuffer<int2> edges;     // for each cube edge, the indices of its two corners.
StructuredBuffer<int> edgeTable;  // indexed by corner mask: bitmask of the edges that cross the surface.

// mesh data:
// verts is used with different contents by different kernels:
//  - SampleSDF / CalculateSurface / Triangulate: per voxel x = field value, y = packed color,
//    z = corner mask bits; per vertex w = index of the voxel that spawned it.
//  - Smoothing: xyz = vertex position, w = encoded normal.
//    NOTE(review): presumably a different buffer is bound as verts for Smoothing - confirm on the CPU side.
RWStructuredBuffer<float4> verts; 
RWStructuredBuffer<float4> outputVerts; // per vertex: xyz position, w encoded normal.
RWStructuredBuffer<float4> colors;      // per vertex color.
RWStructuredBuffer<float4> velocities;  // per vertex velocity.
RWStructuredBuffer<int> quads;          // index buffer, six indices (two triangles) per quad.

// dispatchBuffer layout as used in this file: [0] thread group count (written by FixArgsBuffer),
// [3] element count used as the thread bound by every kernel, [4] counter atomically
// incremented by CalculateSurface for each vertex it appends.
RWStructuredBuffer<uint> dispatchBuffer; 
RWStructuredBuffer<uint> dispatchBuffer2; // [4] is atomically incremented by Triangulate for each quad it appends.
RWStructuredBuffer<indirectDrawIndexedArgs> indirectBuffer; // element 0 is filled by FillIndirectDrawBuffer.

uint currentBatch; // not read by any kernel visible in this file.

float isosurface;         // initial field value at every sample; particle contributions are subtracted from it.
uint descentIterations;   // number of gradient descent steps applied to each vertex in CalculateSurface.
float descentIsosurface;  // extra offset added to the field value during gradient descent.
float descentSpeed;       // step size of the gradient descent.
float smoothing;          // blend factor between original and neighbor-averaged position/normal in Smoothing.
float bevel;              // 2D mode only: depth of fully-inside vertices and blend factor for normals.
uint dispatchMultiplier;  // FixArgsBuffer: multiplies the counter to obtain the next pass' element count.
uint countMultiplier;     // FixArgsBuffer: multiplies the counter in place (0 resets it).
uint instanceCount;       // instance count written to the indirect draw arguments.

// Converts voxel coordinates local to a chunk into the voxel's offset inside that chunk.
// Offsets are Morton codes: built from xy only when mode == 1, from xyz otherwise.
uint voxelCoordToOffset(int3 coord)
{
    // every mode other than 1 takes the three-component encoding.
    if (mode != 1)
        return (int)EncodeMorton3((uint3)coord.xyz);

    return (int)EncodeMorton2((uint2)coord.xy);
}

// Looks up the chunk stored at the given chunk coordinates in the hash table.
// Starts at the slot given by the key's hash and probes consecutive slots (wrapping
// around at maxChunks) until the key is found, an INVALID (empty) slot is reached, or
// every slot has been visited.
// Returns the chunk's handle, or INVALID when no chunk exists at those coordinates.
uint LookupChunk(int3 coords)
{
    uint wantedKey = VoxelID(coords);
    uint probe = hash(wantedKey);

    // at most, visit every slot of the table once.
    for (uint visited = 0; visited < maxChunks; ++visited)
    {
        // hit: this slot holds the chunk we are looking for.
        if (hashtable[probe].key == wantedKey)
            return hashtable[probe].handle;

        // an empty slot terminates the probe sequence: the key is not in the table.
        if (hashtable[probe].key == INVALID)
            break;

        probe = (probe + 1) % maxChunks;
    }

    return INVALID;
}

// Returns the global index of a voxel given the coordinates of a chunk and voxel
// coordinates relative to that chunk. The voxel coordinates may fall outside the chunk,
// in which case the neighboring chunk that contains them is looked up instead.
//  chunkCoords:   coordinates of the reference chunk.
//  voxelCoords:   voxel coordinates relative to the reference chunk.
//  voxelsInChunk: number of voxels stored per chunk.
// Returns INVALID if the chunk containing the voxel does not exist.
uint GetVoxelIndex(int3 chunkCoords, int3 voxelCoords, int voxelsInChunk)
{
    // per-axis step to the chunk that actually contains the voxel:
    // -1 for negative coordinates, coordinate / resolution otherwise.
    int3 chunkStep = voxelCoords < 0 ? -1 : voxelCoords / chunkResolution; 

    uint owner = LookupChunk(chunkCoords + chunkStep);
    if (owner == INVALID)
        return INVALID;

    // wrap the coordinates back into the owning chunk and append the in-chunk offset:
    return owner * voxelsInChunk + voxelCoordToOffset(nfmod(voxelCoords, chunkResolution));
}

// Samples the fluid field at the corner of every voxel, one thread per voxel.
// For each sample it visits the 3x3x3 (3x3 in 2D) block of particle grid cells around the
// sample position at every populated grid level and accumulates:
//  - the field value: isosurface minus the sum of volume-weighted Poly6 kernel contributions,
//  - a weighted average of particle colors and velocities, using a tighter kernel.
// Outputs: verts[i].x (field value), verts[i].y (packed color), voxelVelocities[i].
[numthreads(128, 1, 1)]
void SampleSDF (uint3 id : SV_DispatchThreadID)  // one thread per voxel.
{
    uint i = id.x;
    if (i >= dispatchBuffer[3]) return;
    
    // voxels per chunk: chunkResolution^2 in 2D, chunkResolution^3 in 3D (16 / 64 for a resolution of 4).
    // Integer multiplies are used instead of pow(): pow() is a floating point approximation and
    // truncating its result to int can come out one short of the exact power.
    int voxelsInChunk = (int)(chunkResolution * chunkResolution);
    if (mode != 1)
        voxelsInChunk *= (int)chunkResolution;

    // calculate chunk index:
    int chunkIndex = i / voxelsInChunk;   

    // get offset of voxel within chunk:
    int cornerOffset = i - chunkIndex * voxelsInChunk;
    int3 voxelCoords;

    if (mode == 1)
        voxelCoords = (int3)DecodeMorton2((uint)cornerOffset);
    else
        voxelCoords = (int3)DecodeMorton3((uint)cornerOffset);

    int3 cornerCoords = chunkCoords[chunkIndex] * chunkResolution + voxelCoords;

    // calculate sampling position:
    float3 samplePos = chunkGridOrigin + cornerCoords * voxelSize;
    
    float dist = isosurface;
    float4 color = FLOAT4_ZERO;    // xyz: weighted color sum, w: total weight.
    float4 velocity = FLOAT4_ZERO; // weighted velocity sum.

    // levelPopulation[0] holds the number of populated levels, followed by the levels themselves.
    for (uint m = 1; m <= levelPopulation[0]; ++m)
    {
        uint l = levelPopulation[m];
        float cellSize = CellSizeOfLevel(l);

        int3 cell = floor((samplePos - solverBounds[0].min_.xyz) / cellSize); // TODO: not necessary to subtract solverBounds?
        int3 minCell = cell - 1;
        int3 maxCell = cell + 1;

        // in 2D all particles live in the z = 0 layer of cells.
        if (mode == 1)
            minCell[2] = maxCell[2] = 0;
       
        for (int x = minCell[0]; x <= maxCell[0]; ++x)
        {
            for (int y = minCell[1]; y <= maxCell[1]; ++y)
            {
                for (int z = minCell[2]; z <= maxCell[2]; ++z)
                {
                    int4 neighborCoord = int4(x, y, z, l);
                    int cellIndex = gridHashToSortedIndex[GridHash(neighborCoord)];
                    uint n = cellOffsets[cellIndex]; 
                    uint end = n + cellCounts[cellIndex];

                    for (;n < end; ++n)
                    {
                        float3 radii = sortedPrincipalRadii[n].xyz;

                        // due to hash collisions, two neighboring cells might map to the same
                        // hash bucket, and we'll add the same set of particles twice to the neighbors list.
                        // So we only consider particles that have the same spatial coordinates as the cell.
                        uint level = GridLevelForSize(radii.x);
                        float particleCellSize = CellSizeOfLevel(level);
                        int4 particleCoord = int4(floor((sortedPositions[n].xyz - solverBounds[0].min_.xyz)/ particleCellSize).xyz,level);

                        if (any (particleCoord - neighborCoord))
                            continue;

                        float3 normal = samplePos - sortedPositions[n].xyz;  
                        if (mode == 1)
                            normal[2] = 0;

                        // only update distance if within anisotropic kernel radius:
                        float maxDistance = radii.x + voxelSize * 1.42f; 
                        float r = dot(normal, normal);
                        if (r <= maxDistance * maxDistance)
                        {
                            // bring the offset into the particle's local frame and rescale it by the
                            // anisotropic radii, so the kernel can be evaluated as if it were isotropic:
                            normal = rotate_vector(q_conj(sortedOrientations[n]), normal.xyz) / radii;
                            float d = length(normal) * radii.x;

                            // sortedPositions.w is volume (1/normalized density):
                            float w = sortedPositions[n].w * Poly6(d,radii.x);
                            
                            dist -= w;
                            
                            // tighter smoothing kernel for color and velocities:
                            float w2 = 1-saturate(r / (radii.x * radii.x));
                            color += float4(sortedColors[n].xyz * w2, w2);
                            velocity += sortedVelocities[n] * w2; // 4th component is length(angularVel), vorticity intensity.
                        }
                        
                    }
                }
            }
        }
    }
    
    // normalize the weighted sums. When no particle fell inside the tighter kernel the total
    // weight is zero: leave color and velocity at zero instead of evaluating 0/0.
    float totalWeight = color.w;
    if (totalWeight > 0)
    {
        color /= totalWeight;
        velocity /= totalWeight;
    }

    verts[i].x = dist;
    verts[i].y = PackFloatRGBA(color);

    voxelVelocities[i] = velocity;
}

// Trilinearly interpolates the field inside a voxel and estimates its gradient direction.
//  distancesA: field values at the four corners with x = 0, ordered (y0z0, y0z1, y1z0, y1z1).
//  distancesB: field values at the four corners with x = 1, in the same order.
//  nPos:       position inside the voxel, in voxel-normalized coordinates.
//  normal:     (out) normalized gradient of the interpolated field at nPos.
// Returns the interpolated field value at nPos.
float EvaluateSDF (float4 distancesA, float4 distancesB, in float3 nPos, out float3 normal)
{
    float tx = nPos[0];
    float ty = nPos[1];
    float tz = nPos[2];

    // interpolate along x, then along y; what is left is the field along the z line through nPos:
    float4 alongX = distancesA + (distancesB - distancesA) * tx;
    float2 alongXY = alongX.xy + (alongX.zw - alongX.xy) * ty;

    // field on the x == 0 and x == 1 faces, at (y,z) of nPos:
    float2 faceX0 = distancesA.xy + (distancesA.zw - distancesA.xy) * ty;
    float2 faceX1 = distancesB.xy + (distancesB.zw - distancesB.xy) * ty;
    float fieldX0 = faceX0[0] + (faceX0[1] - faceX0[0]) * tz;
    float fieldX1 = faceX1[0] + (faceX1[1] - faceX1[0]) * tz;

    // field on the y == 0 and y == 1 faces, at (x,z) of nPos:
    float fieldY0 = alongX[0] + (alongX[1] - alongX[0]) * tz;
    float fieldY1 = alongX[2] + (alongX[3] - alongX[2]) * tz;

    // field on the z == 0 and z == 1 faces, at (x,y) of nPos:
    float fieldZ0 = alongXY[0];
    float fieldZ1 = alongXY[1];

    // the gradient is the difference between opposite faces along each axis:
    normal = normalize(float3(fieldX1 - fieldX0, fieldY1 - fieldY0, fieldZ1 - fieldZ0));

    // final interpolation along z:
    return fieldZ0 + (fieldZ1 - fieldZ0) * tz;
}

// Places at most one mesh vertex inside every voxel, one thread per voxel.
// Steps:
//  1. gather the field value, color and velocity at the voxel's corners (8 in 3D, 4 in 2D).
//     If any corner belongs to a chunk that does not exist, the voxel is skipped.
//  2. build the corner mask (bit j set when corner j has a field value >= 0) and store it in verts[i].z.
//  3. skip voxels that do not produce a vertex: no crossed edges in 3D, all corners >= 0 in 2D.
//  4. average the edge crossing points to get an initial vertex position, then refine it with
//     gradient descent on the trilinearly interpolated field.
//  5. append the vertex: position + encoded normal to outputVerts, color and velocity to their buffers.
// voxelToVertex[i] ends up holding the vertex index, or INVALID if the voxel has no vertex.
[numthreads(128, 1, 1)]
void CalculateSurface (uint3 id : SV_DispatchThreadID) // once per voxel.
{
    uint i = id.x;
    if (i >= dispatchBuffer[3]) return;

    // voxels per chunk: chunkResolution^2 in 2D, chunkResolution^3 in 3D (16 / 64 for a resolution of 4).
    // Integer multiplies are used instead of pow(): pow() is a floating point approximation and
    // truncating its result to int can come out one short of the exact power.
    int voxelsInChunk = (int)(chunkResolution * chunkResolution);
    if (mode != 1)
        voxelsInChunk *= (int)chunkResolution;

    int verticesPerVoxel = mode == 1 ? 4 : 8; // 8 vertices in 3D, 4 in 2D.
    float3 dimensionMask = mode == 1 ? float3(1, 1, 0) : float3(1, 1, 1);
    int edgeCount = mode == 1 ? 4 : 12;

    // initialize voxel with invalid vertex:
    voxelToVertex[i] = INVALID;

    // calculate chunk index:
    int chunkIndex = i / voxelsInChunk;   
    
    // get offset of voxel within chunk:
    int voxelOffset = i - chunkIndex * voxelsInChunk; 
    int3 voxelCoords;

    if (mode == 1)
        voxelCoords = (int3)DecodeMorton2((uint)voxelOffset);
    else
        voxelCoords = (int3)DecodeMorton3((uint)voxelOffset);
   
    // get samples at voxel corners:
    float samples[8];
    float4 vcolor[8];
    float4 vvelo[8];
    uint cornerMask = 0;

    // initialize all samples to zero (last 4 samples aren't written to in 2D).
    int j = 0;
    for (j = 0; j < 8; ++j)
        samples[j] = 0;

    for (j = 0; j < verticesPerVoxel; ++j)
    {
        uint v = GetVoxelIndex(chunkCoords[chunkIndex], voxelCoords + corners[j], voxelsInChunk);

        // a corner in a missing chunk means the voxel cannot be evaluated: no vertex.
        if (v == INVALID)
            return;

        samples[j] = verts[v].x;
        vcolor[j] = UnpackFloatRGBA(verts[v].y);
        vvelo[j] = voxelVelocities[v];
        cornerMask |= samples[j] >= 0 ? (1 << j) : 0;
    }

    // store cornerMask (bit pattern reinterpreted as float, read back by Triangulate):
    verts[i].z = asfloat(cornerMask);

    int edgeMask = edgeTable[cornerMask];

    // if the voxel does not intersect the surface, return.
    // In 2D only voxels with every corner >= 0 are dropped: voxels with no crossed
    // edges but some corner < 0 still get a vertex (see the fallback below).
    if ((mode == 1 && cornerMask == 0xf) ||
        (mode == 0 && edgeMask == 0))
        return;

    // calculate vertex position using edge crossings:
    float3 normalizedPos = float3(0,0,0);
    float4 velocity = FLOAT4_ZERO;
    float4 color = FLOAT4_ZERO;
    int intersections = 0;
    
    for (j = 0; j < edgeCount; ++j)
    {
        if((edgeMask & (1 << j)) == 0) 
            continue;
           
        // zero crossing of the linearly interpolated field along the edge:
        int2 e = edges[j];
        float t = -samples[e.x] / (samples[e.y] - samples[e.x]);

        normalizedPos += lerp(corners[e.x],corners[e.y],t);
        color +=         lerp(vcolor[e.x], vcolor[e.y], t);
        velocity +=      lerp(vvelo[e.x],  vvelo[e.y],  t);
        intersections++;
    }

    // intersections will always be > 0 in 3D:
    if (intersections > 0)
    {
        normalizedPos /= intersections;
        color /= intersections;
        velocity /= intersections;
    }
    else
    {
        // 2D voxel without edge crossings: place the vertex at the voxel center, pushed
        // along -z by the bevel amount, and average the four corner attributes.
        normalizedPos = float3(0.5f, 0.5f, -bevel);
        color = (vcolor[0] + vcolor[1] + vcolor[2] + vcolor[3]) * 0.25f;
        velocity = (vvelo[0] + vvelo[1] + vvelo[2] + vvelo[3]) * 0.25f;
    }

    // corner samples arranged as EvaluateSDF expects: A holds the x = 0 corners, B the x = 1 corners.
    float4 distancesA = float4(samples[0], samples[4], samples[2], samples[6]);
    float4 distancesB = float4(samples[1], samples[5], samples[3], samples[7]);
    
    // gradient descent:
    float3 normal;
    for(uint k = 0; k < descentIterations; ++k)
    {
        float d = EvaluateSDF(distancesA, distancesB, normalizedPos, normal);
        normalizedPos -= descentSpeed * normal * (d + isosurface + descentIsosurface);
    }
    
    // final normal evaluation:
    EvaluateSDF(distancesA, distancesB, normalizedPos, normal);

    // modify normal in 2D mode:
    if (mode == 1)
       normal = lerp(float3(0,0,-1),float3(normal.xy,-normal.z),bevel); // no bevel, flat normals

    // Append vertex: reserve a slot with the atomic counter and keep the index in a local,
    // so the remaining writes do not have to read it back from voxelToVertex.
    uint vertexIndex;
    InterlockedAdd(dispatchBuffer[4],1,vertexIndex);
    voxelToVertex[i] = vertexIndex;

    // remember which voxel spawned this vertex:
    verts[vertexIndex].w = i;

    float3 voxelCorner = chunkGridOrigin + (float3)(chunkCoords[chunkIndex] * chunkResolution + voxelCoords) * voxelSize;
    outputVerts[vertexIndex] = float4(voxelCorner * dimensionMask + normalizedPos * voxelSize, encode(normal));
    colors[vertexIndex] = color;
    velocities[vertexIndex] = velocity;
}

// Emits the quads owned by each vertex and records its adjacent vertices, one thread per vertex.
// Steps:
//  1. recover the voxel that spawned the vertex and its corner/edge masks.
//  2. for each candidate quad (3 in 3D, 1 in 2D), fetch the vertices of the three neighboring
//     voxels; if all exist, append two triangles to `quads`, splitting along the shorter diagonal.
//  3. fill vertexAdjacency[v0*6 .. v0*6+5] with the vertices of face-adjacent voxels
//     (INVALID where there is none); these are read by the Smoothing kernel.
[numthreads(128, 1, 1)]
void Triangulate (uint3 id : SV_DispatchThreadID) 
{
    uint v0 = id.x;
    if (v0 >= dispatchBuffer[3]) return;

    // voxels per chunk: chunkResolution^2 in 2D, chunkResolution^3 in 3D (16 / 64 for a resolution of 4).
    // Integer multiplies are used instead of pow(): pow() is a floating point approximation and
    // truncating its result to int can come out one short of the exact power.
    int voxelsInChunk = (int)(chunkResolution * chunkResolution);
    if (mode != 1)
        voxelsInChunk *= (int)chunkResolution;

    int quadCount = 3 - mode * 2; // 3 quads in 3D, 1 in 2D.
    int adjacentCount = 6 - mode * 2; // 6 adjacent voxels in 3D, 4 in 2D.
    
    // get index of the voxel that spawned this vertex:
    uint i = verts[v0].w;
        
    // calculate chunk index and look up coordinates:
    int chunkIndex = i / voxelsInChunk;        
    int3 chunk = chunkCoords[chunkIndex];
    
    // get offset of voxel within chunk:
    int voxelOffset = i - chunkIndex * voxelsInChunk;
    int3 voxelCoords;

    if (mode == 1)
        voxelCoords = (int3)DecodeMorton2((uint)voxelOffset);
    else
        voxelCoords = (int3)DecodeMorton3((uint)voxelOffset);

    // corner mask was stored by CalculateSurface as raw bits in a float:
    uint cornerMask = asuint(verts[i].z);
    int edgeMask = edgeTable[cornerMask];
    
    // get winding order using last bit of cornermask, which indicates corner sign:
    // in 2D, cornerMask >> 7 is always 0, so we get the second winding order.
    int3 windingOrder = (cornerMask >> 7) ? quadWindingOrder[0] : quadWindingOrder[1];
    
    // Retrieve adjacent voxels (those at corners 1..6 of this voxel; 1..4 in 2D):
    int j;
    uint adjacent[6];
    for (j = 0; j < adjacentCount; ++j)
        adjacent[j] = GetVoxelIndex(chunk, voxelCoords + corners[j + 1], voxelsInChunk);
        
    // Iterate over all potential quads, append those needed:
    for (j = 0; j < quadCount; ++j)
    {
        // if the edge is not crossing the surface, skip it (3D only)
        if (mode == 0 && (edgeMask & (1 << j)) == 0)
            continue;

        // calculate final neighbor indices (table entries are corner indices, `adjacent` starts at corner 1):
        uint3 neighbors = uint3(quadNeighborIndices[j][windingOrder[0]]-1,
                                quadNeighborIndices[j][windingOrder[1]]-1,
                                quadNeighborIndices[j][windingOrder[2]]-1);

        // get vertex indices for all voxels involved: 
        uint v1 = voxelToVertex[adjacent[neighbors[0]]];
        uint v2 = voxelToVertex[adjacent[neighbors[1]]];
        uint v3 = voxelToVertex[adjacent[neighbors[2]]];

        // if any of the vertices is invalid, skip the quad:
        if (v1 == INVALID || v2 == INVALID || v3 == INVALID)
            continue;

        // append a new quad (six indices each):
        uint baseIndex;
        InterlockedAdd(dispatchBuffer2[4],1,baseIndex);
        baseIndex *= 6;

        // flip edge if necessary, to always use the shortest diagonal.
        // The 1.1 factor only flips when the v0-v2 diagonal is noticeably longer.
        float3 diag1 = outputVerts[v0].xyz - outputVerts[v2].xyz;
        float3 diag2 = outputVerts[v1].xyz - outputVerts[v3].xyz;
        if (dot(diag1,diag1) > dot(diag2,diag2) * 1.1)
        {
            // split along v1-v3:
            quads[baseIndex] = v1;
            quads[baseIndex+1] = v2;
            quads[baseIndex+2] = v3;

            quads[baseIndex+3] = v0;
            quads[baseIndex+4] = v1;
            quads[baseIndex+5] = v3;
        }
        else
        {
            // split along v0-v2:
            quads[baseIndex] = v0;
            quads[baseIndex+1] = v1;
            quads[baseIndex+2] = v2;

            quads[baseIndex+3] = v3;
            quads[baseIndex+4] = v0;
            quads[baseIndex+5] = v2;
        }
    }
    
    // Move adjacent voxel in Z axis to last position, so that 2D adjacent voxels are the first 4.
    // Final layout: [0] +x, [1] +y, [2] -y, [3] -x, [4] -z, [5] +z.
    adjacent[5] = adjacent[3];
    adjacent[2] = GetVoxelIndex(chunk, voxelCoords + int3(0, -1, 0), voxelsInChunk);
    adjacent[3] = GetVoxelIndex(chunk, voxelCoords + int3(-1, 0, 0), voxelsInChunk);
    adjacent[4] = GetVoxelIndex(chunk, voxelCoords + int3(0, 0, -1), voxelsInChunk);
    
    // initialize vertex adjacency to INVALID.
    for (j = 0; j < 6; ++j)
        vertexAdjacency[v0*6 + j] = INVALID;

    // Determine adjacent surface voxels for smoothing:
    bool isAdjacent;
    for (j = 0; j < adjacentCount; ++j)
    {
        if (adjacent[j] != INVALID)
        {
            // adjacent if this does not intersect the surface or both intersect the surface.
            isAdjacent = (edgeMask == 0 || edgeTable[asuint(verts[adjacent[j]].z)] != 0) &&

            // in 3D mode, it should also intersect any of the face edges to be considered adjacent:
                         (mode == 1 || any(edgeMask & (1 << faceNeighborEdges[j])));

            vertexAdjacency[v0 * 6 + j] = isAdjacent ? voxelToVertex[adjacent[j]] : INVALID;
        }
    }
} 

// Smooths vertex positions and normals, one thread per vertex.
// Each vertex is blended towards the average of itself and its valid adjacent vertices
// (up to six, listed in vertexAdjacency), using `smoothing` as the blend factor.
// Reads positions (xyz) and encoded normals (w) from verts, writes the result to outputVerts.
[numthreads(128, 1, 1)]
void Smoothing (uint3 id : SV_DispatchThreadID) 
{
    uint self = id.x;
    if (self >= dispatchBuffer[3]) return;

    float3 ownPosition = verts[self].xyz;
    float3 ownNormal = decode(verts[self].w);

    // running sums start with the vertex itself; w counts the contributions.
    float4 positionSum = float4(ownPosition, 1);
    float4 normalSum = float4(ownNormal, 1);

    for (int slot = 0; slot < 6; ++slot)
    {
        uint neighbor = vertexAdjacency[self * 6 + slot];

        // empty adjacency slot.
        if (neighbor == INVALID)
            continue;

        positionSum += float4(verts[neighbor].xyz, 1);
        normalSum += float4(decode(verts[neighbor].w), 1);
    }

    // turn the sums into averages:
    positionSum.xyz /= positionSum.w;
    normalSum.xyz /= normalSum.w;

    // blend between the original and the averaged values:
    float3 smoothedPosition = lerp(ownPosition, positionSum.xyz, smoothing);
    float3 smoothedNormal = normalize(lerp(ownNormal, normalSum.xyz, smoothing));

    outputVerts[self] = float4(smoothedPosition, encode(smoothedNormal));
} 

// Turns the counter accumulated in dispatchBuffer[4] into the arguments of the next pass.
// Runs as a single thread.
//  dispatchBuffer[3]: element count for the next pass (counter * dispatchMultiplier).
//  dispatchBuffer[0]: number of 128-thread groups needed to cover that count.
//  dispatchBuffer[4]: counter * countMultiplier (a multiplier of 0 resets it).
[numthreads(1, 1, 1)]
void FixArgsBuffer (uint3 id : SV_DispatchThreadID) 
{
    uint counter = dispatchBuffer[4];
    uint elementCount = counter * dispatchMultiplier;

    dispatchBuffer[3] = elementCount;

    // one group more than the integer quotient, so at least one group is always dispatched.
    dispatchBuffer[0] = elementCount / 128 + 1;

    // used to zero out fourth component if needed.
    dispatchBuffer[4] = counter * countMultiplier;
}

// Writes the indexed indirect draw arguments for the generated mesh into indirectBuffer[0].
// Runs as a single thread.
[numthreads(1, 1, 1)]
void FillIndirectDrawBuffer (uint3 id : SV_DispatchThreadID) 
{
    // start from an all-zero record: start index, base vertex and start instance stay at 0.
    indirectDrawIndexedArgs args = (indirectDrawIndexedArgs)0;

    args.instanceCount = instanceCount;
    args.indexCountPerInstance = dispatchBuffer[3] * 6; // number of quads * 6

    indirectBuffer[0] = args;
}

